{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from math import ceil\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from utils.input_pipeline import get_image_folders\n",
    "from utils.training import train, optimization_step\n",
    "    \n",
    "# confirm that a CUDA device is visible before any .cuda() call below\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# turn on cuDNN's benchmark mode for the rest of the session\n",
    "torch.backends.cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create data iterators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# mini-batch size used by the training DataLoader\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_folder, val_folder = get_image_folders()\n",
    "\n",
    "# options shared by the training and validation loaders\n",
    "loader_options = dict(num_workers=4, pin_memory=True)\n",
    "\n",
    "# training batches are reshuffled; validation keeps a fixed order\n",
    "# and uses a larger batch\n",
    "train_iterator = DataLoader(\n",
    "    train_folder, batch_size=batch_size, shuffle=True, **loader_options\n",
    ")\n",
    "val_iterator = DataLoader(\n",
    "    val_folder, batch_size=256, shuffle=False, **loader_options\n",
    ")\n",
    "\n",
    "# number of training samples (length of the folder's `imgs`)\n",
    "train_size = len(train_folder.imgs)\n",
    "train_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of validation samples (length of the folder's `imgs`)\n",
    "val_size = len(val_folder.imgs)\n",
    "val_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from densenet import DenseNet\n",
    "model = DenseNet()\n",
    "\n",
    "# Resume from the weights produced by step 2. `map_location` keeps every\n",
    "# tensor on the CPU while loading, so the checkpoint can be read no matter\n",
    "# which device it was saved from; the model is moved to the GPU later.\n",
    "step2_state = torch.load(\n",
    "    'model_step2.pytorch_state',\n",
    "    map_location=lambda storage, location: storage\n",
    ")\n",
    "model.load_state_dict(step2_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def select_params(name_matches):\n",
    "    \"\"\"Collect the model parameters whose name satisfies `name_matches`.\"\"\"\n",
    "    return [\n",
    "        param for name, param in model.named_parameters()\n",
    "        if name_matches(name)\n",
    "    ]\n",
    "\n",
    "\n",
    "# split the parameters into groups so the optimizer can treat them differently\n",
    "weights = select_params(\n",
    "    lambda name: 'conv' in name or 'classifier.weight' in name\n",
    ")\n",
    "biases = [model.classifier.bias]\n",
    "bn_weights = select_params(lambda name: 'norm' in name and 'weight' in name)\n",
    "bn_biases = select_params(lambda name: 'norm' in name and 'bias' in name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# only the `weights` group is given an explicit weight decay; the other\n",
    "# groups are passed without per-group overrides\n",
    "decayed_group = {'params': weights, 'weight_decay': 1e-4}\n",
    "plain_groups = [{'params': group} for group in (biases, bn_weights, bn_biases)]\n",
    "params = [decayed_group] + plain_groups\n",
    "\n",
    "# NOTE(review): the torch.optim docs advise moving a model to the GPU before\n",
    "# building its optimizer; here .cuda() is called afterwards - confirm this is\n",
    "# fine for the PyTorch version in use\n",
    "optimizer = optim.SGD(params, lr=1e-4, momentum=0.9, nesterov=True)\n",
    "loss = nn.CrossEntropyLoss().cuda()\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1563"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_epochs = 5\n",
    "\n",
    "# total number of batches in the train set; integer ceiling division, so a\n",
    "# final partial batch is counted as well\n",
    "n_batches = (train_size + batch_size - 1) // batch_size\n",
    "n_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  1.748 1.295  0.568 0.671  0.810 0.877  2096.644\n",
      "1  1.631 1.224  0.592 0.689  0.827 0.886  2080.292\n",
      "2  1.558 1.184  0.609 0.697  0.838 0.892  2084.515\n",
      "3  1.493 1.154  0.622 0.704  0.848 0.893  2083.987\n",
      "4  1.449 1.118  0.632 0.713  0.855 0.899  2083.928\n",
      "CPU times: user 2h 38min 32s, sys: 15min 17s, total: 2h 53min 50s\n",
      "Wall time: 2h 53min 49s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "def optimization_step_fn(model, loss, x_batch, y_batch):\n",
    "    \"\"\"Run `optimization_step` on one batch using the notebook-level `optimizer`.\n",
    "\n",
    "    Adapts `optimization_step` to a four-argument callable by supplying\n",
    "    `optimizer` from the enclosing notebook scope, and returns its result\n",
    "    unchanged.\n",
    "    \"\"\"\n",
    "    return optimization_step(model, loss, x_batch, y_batch, optimizer)\n",
    "\n",
    "all_losses = train(\n",
    "    model, loss, optimization_step_fn,\n",
    "    train_iterator, val_iterator, n_epochs\n",
    ")\n",
    "# columns of the printed log, one row per epoch:\n",
    "# epoch  logloss  accuracy  top5_accuracy  time\n",
    "# (for each metric the first value is train, the second is val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# bring the model back to the CPU, then write its state dict to disk\n",
    "torch.save(model.cpu().state_dict(), 'model_step3.pytorch_state')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
